import pygame
import numpy as np
import math
import random
import sys
import time
import traceback
import copy  # For deep copying controllers

# Global fast-mode flag; flipped by toggle_fast_mode()
FAST_MODE = False
# Minimum time in milliseconds between two accepted toggles
FAST_MODE_TOGGLE_COOLDOWN_MS = 250
# Timestamp (ms) of the last accepted toggle. Starts one cooldown in the past
# so that a first toggle at any timestamp >= 0 is accepted.
LAST_FAST_MODE_TOGGLE_MS = -FAST_MODE_TOGGLE_COOLDOWN_MS
# NOTE(review): presumably "render only every Nth frame while fast mode is on";
# the constant is not used in this part of the file - confirm against its user.
FAST_MODE_RENDER_EVERY = 5
TRACK_DIFFICULTY = "hard"  # ["easy", "medium", "hard", "random"]


def toggle_fast_mode():
    """Flip the global fast-mode flag unless the toggle cooldown is still active.

    The current time is taken from pygame's tick counter when pygame is
    initialised, otherwise from the wall clock (both in milliseconds). A call
    arriving less than FAST_MODE_TOGGLE_COOLDOWN_MS after the last accepted
    toggle is treated as a duplicate key event and changes nothing.

    Returns:
        The value of FAST_MODE after the call.
    """
    global FAST_MODE, LAST_FAST_MODE_TOGGLE_MS

    if pygame.get_init():
        timestamp_ms = pygame.time.get_ticks()
    else:
        timestamp_ms = int(time.time() * 1000)

    elapsed_ms = timestamp_ms - LAST_FAST_MODE_TOGGLE_MS
    if elapsed_ms >= FAST_MODE_TOGGLE_COOLDOWN_MS:
        FAST_MODE = not FAST_MODE
        LAST_FAST_MODE_TOGGLE_MS = timestamp_ms
        status = 'ON - FPS limit removed' if FAST_MODE else 'OFF - FPS limited to 60'
        print(f"Fast mode {status}")

    return FAST_MODE

class NeuralNetworkController:
    """Small feed-forward network that turns sensor readings into driving commands.

    Layout: ``input_size`` inputs -> 10 LeakyReLU hidden units -> 2 tanh outputs
    (acceleration, steering). Each weight matrix carries its bias as row 0, so
    ``weights_ih`` has shape ``(input_size + 1, hidden_size)`` and ``weights_ho``
    has shape ``(hidden_size + 1, output_size)``.

    The activations of the most recent ``predict`` call are kept on the
    instance so that ``visualize_network`` can draw them.
    """

    def __init__(self, input_size):
        """Create a controller with randomly initialised weights.

        Args:
            input_size: Total number of network inputs. The last two inputs are
                the car's steering angle and velocity; the rest are sensors.
        """
        self.input_size = input_size
        self.sensor_count = input_size - 2  # Last 2 inputs are steering and velocity
        self.hidden_size = 10
        self.output_size = 2  # Acceleration and steering

        # Weights are drawn uniformly from [-s, s] with
        # s = sqrt(2 / (fan_in + fan_out)); fan_in includes the bias row.
        # The draw order (input->hidden first, then hidden->output) is kept
        # fixed so that seeded runs stay reproducible.
        scale_ih = np.sqrt(2.0 / (self.input_size + 1 + self.hidden_size))
        self.weights_ih = np.random.uniform(
            -scale_ih, scale_ih, (self.input_size + 1, self.hidden_size))

        scale_ho = np.sqrt(2.0 / (self.hidden_size + 1 + self.output_size))
        self.weights_ho = np.random.uniform(
            -scale_ho, scale_ho, (self.hidden_size + 1, self.output_size))

        # LeakyReLU slope for negative inputs
        self.alpha = 0.01

        # Activations of the last predict() call, used by visualize_network()
        self.last_inputs = None
        self.last_hidden_raw = None
        self.last_hidden_activated = None
        self.last_outputs = None

    def leaky_relu(self, x):
        """Return ``x`` where positive and ``alpha * x`` where negative (element-wise)."""
        return np.maximum(self.alpha * x, x)

    def predict(self, sensor_readings, current_steering, current_velocity):
        """Run one forward pass and return the driving commands.

        Args:
            sensor_readings: Sequence of ``sensor_count`` distances; each is
                divided by 200 before entering the network.
            current_steering: Current steering angle, divided by pi/3.
            current_velocity: Current velocity, divided by 5.0.

        Returns:
            Tuple ``(acceleration, steering)``; each is a tanh output scaled by
            5 and therefore lies in [-5, 5].
        """
        # Normalised inputs are kept as a plain list for the visualisation
        normalized_inputs = [reading / 200 for reading in sensor_readings]
        normalized_inputs.append(current_steering / (math.pi/3))
        normalized_inputs.append(current_velocity / 5.0)
        self.last_inputs = normalized_inputs

        # Prepend the constant 1.0 that multiplies the bias row (row 0)
        inputs_with_bias = np.append(1.0, np.array(normalized_inputs))

        # Hidden layer
        hidden_raw = np.dot(inputs_with_bias, self.weights_ih)
        hidden_activated = self.leaky_relu(hidden_raw)
        self.last_hidden_raw = hidden_raw
        self.last_hidden_activated = hidden_activated

        # Output layer (tanh keeps both outputs in [-1, 1])
        hidden_with_bias = np.append(1.0, hidden_activated)
        outputs = np.tanh(np.dot(hidden_with_bias, self.weights_ho))
        self.last_outputs = outputs

        # Scale from [-1, 1] to [-5, 5]
        acceleration = 5 * outputs[0]
        steering = 5 * outputs[1]
        return acceleration, steering

    def mutate(self, mutation_rate=0.3, mutation_amount=0.3):
        """Return a mutated copy of this controller; ``self`` is left untouched.

        Args:
            mutation_rate: Probability that any single weight is perturbed.
            mutation_amount: Perturbations are drawn uniformly from
                ``[-mutation_amount, mutation_amount]``.

        Returns:
            A new NeuralNetworkController with the perturbed weights.
        """
        new_controller = NeuralNetworkController(self.input_size)

        # Start from copies of our parameters
        new_controller.weights_ih = self.weights_ih.copy()
        new_controller.weights_ho = self.weights_ho.copy()
        new_controller.alpha = self.alpha

        # The boolean mask selects which weights receive a perturbation
        mask_ih = np.random.random(self.weights_ih.shape) < mutation_rate
        new_controller.weights_ih += mask_ih * np.random.uniform(
            -mutation_amount, mutation_amount, self.weights_ih.shape)

        mask_ho = np.random.random(self.weights_ho.shape) < mutation_rate
        new_controller.weights_ho += mask_ho * np.random.uniform(
            -mutation_amount, mutation_amount, self.weights_ho.shape)

        return new_controller

    @staticmethod
    def _hidden_intensity(activation):
        """Map a hidden activation to a display intensity in [0, 1].

        ``(activation + 1) / 2`` is clamped to [0, 1] *before* the square root
        is taken. LeakyReLU outputs are unbounded, so without the clamp an
        activation below -1 would produce the square root of a negative number
        (NaN) and end up rendered as fully active.
        """
        level = max(0.0, min(1.0, (activation + 1) / 2))
        return level ** 0.5  # Square root makes small differences more visible

    @staticmethod
    def _signed_color(value, intensity):
        """Return red (negative ``value``) or green (otherwise), fading to white as ``intensity`` drops to 0."""
        fade = int(255 * (1 - intensity))
        if value < 0:
            return (255, fade, fade)
        return (fade, 255, fade)

    def _input_color(self, index, value):
        """Return the RGB colour of input neuron ``index`` for its normalised ``value``."""
        clamped = max(0, min(1, value))
        if index < self.sensor_count:
            # Distance sensors: blue (0.0 / close) to white (1.0 / distant)
            shade = int(100 + 155 * clamped)
            return (shade, shade, 255)
        if index == self.sensor_count:
            # Steering: red (negative) - white (0) - green (positive)
            return self._signed_color(value, min(1.0, abs(value)))
        # Velocity: blue-ish (slow) to yellow (fast)
        return (int(255 * clamped), int(255 * clamped), int(255 - 155 * clamped))

    @staticmethod
    def _connection_style(weight):
        """Return ``(color, thickness)`` used to draw a connection with this weight.

        Thickness grows from 1 to 5 pixels with the weight's magnitude.
        Positive weights are tinted green and the others red; the tint moves
        towards white as ``|weight| * 5`` approaches 1.
        """
        magnitude = abs(weight)
        if magnitude < 0.05:
            thickness = 1
        elif magnitude < 0.1:
            thickness = 2
        elif magnitude < 0.2:
            thickness = 3
        elif magnitude < 0.3:
            thickness = 4
        else:
            thickness = 5

        tint = int(100 + 155 * min(1.0, magnitude * 5))
        if weight > 0:
            color = (tint, 255, tint)
        else:
            color = (255, tint, tint)
        return color, thickness

    def _draw_connections(self, screen, weights, sources, targets):
        """Draw a line for every significant weight between two layers.

        ``weights[i + 1, j]`` connects ``sources[i]`` to ``targets[j]``; row 0
        holds the bias and is not drawn. Weights with magnitude <= 0.02 are
        skipped to reduce visual clutter.
        """
        for i, source in enumerate(sources):
            for j, target in enumerate(targets):
                weight = weights[i + 1, j]
                if abs(weight) > 0.02:
                    color, thickness = self._connection_style(weight)
                    pygame.draw.line(screen, color, source, target, thickness)

    def visualize_network(self, screen, x, y, width, height):
        """Draw the network and its latest activations into a rectangle on ``screen``.

        Does nothing until ``predict`` has been called at least once.

        Args:
            screen: pygame surface to draw on.
            x, y: Top-left corner of the drawing area.
            width, height: Size of the drawing area.
        """
        if self.last_hidden_activated is None or self.last_outputs is None or self.last_inputs is None:
            return

        # Small font for value labels
        value_font = pygame.font.Font(None, 14)
        label_color = (200, 200, 200)

        # Column x-coordinates of the three layers
        input_x = x + 40
        hidden_x = x + width/2
        output_x = x + width - 40

        # Vertical layout of each layer
        input_y_start = y + 30
        input_y_step = (height - 60) / (self.input_size + 1)
        hidden_y_start = y + 30
        hidden_y_step = (height - 60) / (self.hidden_size)
        output_y_start = y + height/3
        output_y_step = height/3

        input_points = [(int(input_x), int(input_y_start + i * input_y_step))
                        for i in range(self.input_size)]
        hidden_points = [(int(hidden_x), int(hidden_y_start + j * hidden_y_step))
                         for j in range(self.hidden_size)]
        output_points = [(int(output_x), int(output_y_start + k * output_y_step))
                         for k in range(self.output_size)]

        # Input neurons, value label to the left
        for i, (px, py) in enumerate(input_points):
            input_val = self.last_inputs[i]
            input_val_clamped = max(0, min(1, input_val))
            radius = 5 + int(abs(input_val_clamped) * 3)
            pygame.draw.circle(screen, self._input_color(i, input_val), (px, py), radius)

            value_text = value_font.render(f"{input_val:.2f}", True, label_color)
            screen.blit(value_text, (px - value_text.get_width() - 4,
                                     py - value_text.get_height()/2))

        # Hidden neurons, value label below
        for j, (px, py) in enumerate(hidden_points):
            activation = self.last_hidden_activated[j]
            intensity = self._hidden_intensity(activation)
            color = (
                int(255 * intensity),          # R: 0 to 255
                int(100 * intensity),          # G: 0 to 100
                int(255 - 155 * intensity)     # B: 255 to 100
            )
            neuron_radius = 6 + int(intensity * 4)
            pygame.draw.circle(screen, color, (px, py), neuron_radius)

            value_text = value_font.render(f"{activation:.2f}", True, label_color)
            screen.blit(value_text, (px - value_text.get_width()/2,
                                     py + neuron_radius + 2))

        # Output neurons, value and type labels to the right
        for k, (px, py) in enumerate(output_points):
            output_val = self.last_outputs[k]
            # Square root exaggerates small outputs
            color = self._signed_color(output_val, min(1.0, abs(output_val) ** 0.5))
            pygame.draw.circle(screen, color, (px, py), 10)

            value_text = value_font.render(f"{output_val:.2f}", True, label_color)
            screen.blit(value_text, (px + 14, py - value_text.get_height()/2))

            type_text = value_font.render("Acc" if k == 0 else "Str", True, (150, 150, 150))
            screen.blit(type_text, (px + 14, py + 8))

        # Connections are drawn last, in the same order as before:
        # input -> hidden, then hidden -> output.
        self._draw_connections(screen, self.weights_ih, input_points, hidden_points)
        self._draw_connections(screen, self.weights_ho, hidden_points, output_points)
                    
# Track representation
class Track:
    """Closed circuit bounded by an outer and an inner wall polyline."""

    def __init__(self, outer_points, inner_points):
        """Build wall segments, the start/finish line and the checkpoints.

        Args:
            outer_points: Sequence of (x, y) points of the outer wall.
            inner_points: Sequence of (x, y) points of the inner wall; indexed
                with the same indices as ``outer_points``.
        """
        self.outer_points = outer_points
        self.inner_points = inner_points

        # Walls as closed loops of (p1, p2) segments
        self.outer_segments = self._create_segments(outer_points)
        self.inner_segments = self._create_segments(inner_points)
        self.all_segments = self.outer_segments + self.inner_segments

        # The start/finish line joins the first point of each wall
        self.start_finish = (outer_points[0], inner_points[0])
        self.difficulty = "medium"  # Default difficulty
        self.special_features = []  # List of special track features

        # Checkpoints: 20 gates spread evenly over the wall point indices
        point_count = len(outer_points)
        self.num_checkpoints = 20
        gate_indices = (int(gate * point_count / self.num_checkpoints)
                        for gate in range(self.num_checkpoints))
        self.checkpoints = [(outer_points[at], inner_points[at]) for at in gate_indices]

    def _create_segments(self, points):
        """Return the closed loop of (p1, p2) pairs joining consecutive points."""
        total = len(points)
        return [(points[k], points[(k + 1) % total]) for k in range(total)]

    def draw(self, screen):
        """Draw walls (coloured by difficulty), the start line and feature markers."""
        # Wall colour per difficulty; anything else falls back to yellow
        palette = {
            "easy": (0, 255, 0),      # Green
            "medium": (255, 165, 0),  # Orange
            "hard": (255, 0, 0),      # Red
        }
        wall_color = palette.get(getattr(self, 'difficulty', None), (255, 215, 0))
        white = (255, 255, 255)

        pygame.draw.lines(screen, wall_color, True, self.outer_points, 2)
        pygame.draw.lines(screen, wall_color, True, self.inner_points, 2)
        pygame.draw.line(screen, white, self.start_finish[0], self.start_finish[1], 2)

        # Special features: only "split_path" has a visual marker
        for feature in self.special_features:
            if feature[0] != "split_path":
                continue
            start_angle, end_angle = feature[1], feature[2]

            # Mark every wall index whose nominal angle lies inside the split section
            for idx in range(len(self.outer_points)):
                nominal_angle = 2 * math.pi * idx / len(self.outer_points)
                if start_angle <= nominal_angle <= end_angle:
                    pygame.draw.line(screen, white,
                                     self.outer_points[idx], self.inner_points[idx], 1)

# Car physics and sensing
class Car:
    """Simulated car: motion model, wall sensors, collision and progress tracking."""

    def __init__(self, x, y, angle=0, color=(255, 0, 0)):
        """Place a car at (x, y) with heading ``angle`` (radians) and zero velocity."""
        # Pose; the previous position is kept for line-crossing tests
        self.x, self.y = x, y
        self.prev_x, self.prev_y = x, y
        self.angle = angle
        self.color = color  # Car color, default is red

        # Motion state
        self.velocity = 0
        self.steering_angle = 0

        # Sensor configuration
        self.sensor_count = 7
        self.sensor_range = 200

        # Run statistics
        self.laps = 0
        self.alive = True
        self.distance_traveled = 0
        self.time_alive = 0
        self.total_speed = 0
        self.debug_info = []

        # Progress tracking
        self.next_checkpoint = 1       # Start after checkpoint 0 (start line)
        self.checkpoints_hit = 0
        self.angle_progress = 0.0      # Cumulative angular progress around track
        self.max_angle_progress = 0.0  # High-water mark
        self.prev_track_angle = None   # Set on first update

    def get_sensor_readings(self, track):
        """Cast the sensor rays and measure the distance to the nearest wall.

        The rays fan out evenly from 90 degrees on one side of the heading to
        90 degrees on the other.

        Returns:
            ``(readings, sensor_endpoints)``: per ray, the distance to the
            closest wall hit (``sensor_range`` when nothing is hit) and the
            point where the ray ends.
        """
        origin = (self.x, self.y)
        distances = []
        ray_ends = []

        for ray_index in range(self.sensor_count):
            ray_angle = self.angle - math.pi/2 + ray_index * math.pi/(self.sensor_count-1)
            ray_tip = (self.x + math.cos(ray_angle) * self.sensor_range,
                       self.y + math.sin(ray_angle) * self.sensor_range)

            # Without a hit, the ray reports its full range
            nearest = self.sensor_range
            nearest_point = ray_tip

            for wall_start, wall_end in track.all_segments:
                hit = self._line_intersection(origin, ray_tip, wall_start, wall_end)
                if not hit:
                    continue
                gap = math.sqrt((hit[0] - self.x)**2 +
                                (hit[1] - self.y)**2)
                if gap < nearest:
                    nearest, nearest_point = gap, hit

            distances.append(nearest)
            ray_ends.append(nearest_point)

        return distances, ray_ends

    def _line_intersection(self, line1_start, line1_end, line2_start, line2_end):
        """Return the intersection point of two segments, or None if they do not cross."""
        x1, y1 = line1_start
        x2, y2 = line1_end
        x3, y3 = line2_start
        x4, y4 = line2_end

        denominator = (y4-y3)*(x2-x1) - (x4-x3)*(y2-y1)
        if denominator == 0:
            # Parallel (or degenerate) segments
            return None

        # Positions of the crossing along segment 1 (ua) and segment 2 (ub)
        ua = ((x4-x3)*(y1-y3) - (y4-y3)*(x1-x3)) / denominator
        ub = ((x2-x1)*(y1-y3) - (y2-y1)*(x1-x3)) / denominator

        if not (0 <= ua <= 1 and 0 <= ub <= 1):
            return None

        return (x1 + ua * (x2-x1), y1 + ua * (y2-y1))

    def update(self, acceleration, steering):
        """Advance the car by one simulation step using the given control inputs."""
        # Remember where we were, for lap and checkpoint detection
        self.prev_x, self.prev_y = self.x, self.y

        # Acceleration (reduced sensitivity), then friction
        speed = (self.velocity + acceleration * 0.1) * 0.97
        # Clamp velocity — high floor forces the car to learn to steer
        self.velocity = max(2.0, min(5.0, speed))

        # Steering responds less at higher speed and drifts back to centre
        self.steering_factor = 1.0 / (1.0 + abs(self.velocity))
        wheel = (self.steering_angle + steering * self.steering_factor * 0.3) * 0.9
        self.steering_angle = max(-math.pi/3, min(math.pi/3, wheel))

        # Turn, then move along the new heading
        self.angle += self.steering_angle * self.velocity * 0.03
        self.x += math.cos(self.angle) * self.velocity
        self.y += math.sin(self.angle) * self.velocity

        # Bookkeeping for fitness and average speed
        covered = abs(self.velocity)
        self.distance_traveled += covered
        self.time_alive += 1
        self.total_speed += covered

        self.debug_info = [
            f"Velocity: {self.velocity:.2f}",
            f"Steering: {self.steering_angle:.2f}",
            f"Acc Input: {acceleration:.2f}",
            f"Steer Input: {steering:.2f}"
        ]

    def check_collision(self, track):
        """Return True (and mark the car dead) if it is within 8 units of any wall."""
        car_radius = 8

        for wall_start, wall_end in track.all_segments:
            wall_dx = wall_end[0] - wall_start[0]
            wall_dy = wall_end[1] - wall_start[1]
            wall_length = math.sqrt(wall_dx**2 + wall_dy**2)
            if wall_length == 0:
                continue  # Zero-length segment

            unit_x = wall_dx / wall_length
            unit_y = wall_dy / wall_length

            # Project the car onto the wall and clamp to the segment's extent
            along = ((self.x - wall_start[0]) * unit_x +
                     (self.y - wall_start[1]) * unit_y)
            along = max(0, min(wall_length, along))
            nearest_x = wall_start[0] + unit_x * along
            nearest_y = wall_start[1] + unit_y * along

            gap = math.sqrt((self.x - nearest_x)**2 +
                            (self.y - nearest_y)**2)
            if gap < car_radius:
                self.alive = False
                return True

        return False

    def check_lap(self, track):
        """Count a lap when the last move crossed the finish line in the scoring direction."""
        line_start = track.start_finish[0]
        line_end = track.start_finish[1]

        crossed = self._line_intersection((self.prev_x, self.prev_y),
                                          (self.x, self.y),
                                          line_start, line_end)
        if not crossed:
            return False

        move_x = self.x - self.prev_x
        move_y = self.y - self.prev_y
        line_dx = line_end[0] - line_start[0]
        line_dy = line_end[1] - line_start[1]

        # The sign of the 2D cross product tells the crossing direction
        if move_x * line_dy - move_y * line_dx > 0:
            self.laps += 1
            self.next_checkpoint = 1  # Reset checkpoint counter on new lap
            return True

        return False

    def update_angle_progress(self, center_x, center_y):
        """Accumulate the car's signed angular travel around (center_x, center_y)."""
        current = math.atan2(self.y - center_y, self.x - center_x)
        previous = self.prev_track_angle
        self.prev_track_angle = current
        if previous is None:
            return  # First sample: nothing to compare against yet

        step = current - previous
        # Wrap to [-pi, pi]
        if step > math.pi:
            step -= 2 * math.pi
        elif step < -math.pi:
            step += 2 * math.pi

        self.angle_progress += step
        if self.angle_progress > self.max_angle_progress:
            self.max_angle_progress = self.angle_progress

    def check_checkpoints(self, track):
        """Advance to the following checkpoint if the last move crossed the expected one."""
        if not hasattr(track, 'checkpoints'):
            return
        total = track.num_checkpoints
        if self.next_checkpoint >= total:
            return

        gate = track.checkpoints[self.next_checkpoint]
        if self._line_intersection((self.prev_x, self.prev_y),
                                   (self.x, self.y), gate[0], gate[1]):
            self.checkpoints_hit += 1
            self.next_checkpoint = (self.next_checkpoint + 1) % total

    def draw(self, screen, sensor_endpoints=None, show_sensors=True):
        """Draw the car body, a heading marker and, optionally, the sensor rays."""
        car_length = 20
        car_width = 10

        # Half-extent offsets along the heading (fwd) and across it (side)
        fwd_x = math.cos(self.angle) * car_length/2
        fwd_y = math.sin(self.angle) * car_length/2
        side_x = math.sin(self.angle) * car_width/2
        side_y = math.cos(self.angle) * car_width/2

        corners = [
            (self.x + fwd_x - side_x, self.y + fwd_y + side_y),
            (self.x + fwd_x + side_x, self.y + fwd_y - side_y),
            (self.x - fwd_x + side_x, self.y - fwd_y - side_y),
            (self.x - fwd_x - side_x, self.y - fwd_y + side_y),
        ]
        pygame.draw.polygon(screen, self.color, corners)

        # Direction indicator from the centre to the front of the car
        nose = (self.x + fwd_x, self.y + fwd_y)
        pygame.draw.line(screen, (255, 255, 0), (self.x, self.y), nose, 2)

        # Sensor rays, if requested and available
        if show_sensors and sensor_endpoints:
            for ray_end in sensor_endpoints:
                pygame.draw.line(screen, (255, 255, 255),
                                 (self.x, self.y), ray_end, 1)

class HillClimber:
    """Hill-climbing search over car controllers on a single track.

    The climber keeps the best controller found so far together with its
    fitness. Candidates are scored by simulating a car from the configured
    start pose with :meth:`evaluate`.
    """

    # Point around which angular progress is measured. create_track centres
    # its tracks on the middle of an 800x600 area, i.e. (400, 300).
    TRACK_CENTER = (400, 300)

    # HUD colours
    WHITE = (255, 255, 255)
    TEST_CAR_COLOR = (255, 0, 0)
    BEST_CAR_COLOR = (0, 255, 0)
    DIFFICULTY_COLORS = {
        "easy": (0, 255, 0),      # Green
        "medium": (255, 165, 0),  # Orange
        "hard": (255, 0, 0),      # Red
    }

    def __init__(self, track, start_pos, start_angle, input_size=7):
        """Create the climber and score a freshly initialised controller.

        Args:
            track: track object the cars are simulated on.
            start_pos: (x, y) spawn position for every simulated car.
            start_angle: initial heading in radians.
            input_size: number of sensor inputs; the controller receives two
                extra inputs (steering angle and velocity) on top of these.
        """
        self.track = track
        self.start_pos = start_pos
        self.start_angle = start_angle
        self.input_size = input_size

        # Initialize with neural network controller
        self.best_controller = self.create_default_controller()

        # Evaluate the default controller
        self.best_fitness = self.evaluate(self.best_controller)
        print(f"Initial fitness: {self.best_fitness}")

    def create_default_controller(self):
        """Create a default controller with random parameters"""
        return NeuralNetworkController(self.input_size + 2)

    def _advance_car(self, car, controller, readings, label, step, verbose):
        """Run one control/physics step for ``car``; return (acceleration, steering)."""
        acceleration, steering = controller.predict(
            readings,
            car.steering_angle,
            car.velocity
        )

        car.update(acceleration, steering)
        car.update_angle_progress(*self.TRACK_CENTER)

        # Collision, checkpoint and lap checks always run, in this order
        if car.check_collision(self.track) and verbose:
            print(f"{label} car crashed at step {step}")

        car.check_checkpoints(self.track)
        if car.check_lap(self.track) and verbose:
            print(f"{label} car completed lap at step {step}")

        return acceleration, steering

    @staticmethod
    def _process_simulation_events():
        """Handle window events raised while a rendered simulation is running.

        Closing the window or pressing ESC shuts pygame down and exits the
        program; F toggles fast mode.
        """
        for event in pygame.event.get():
            is_key = event.type == pygame.KEYDOWN
            if event.type == pygame.QUIT or (is_key and event.key == pygame.K_ESCAPE):
                # Exit instead of returning a sentinel: rendering callers
                # unpack a 3-tuple and would go on to draw on a closed display.
                pygame.quit()
                sys.exit()
            if is_key and event.key == pygame.K_f:
                toggle_fast_mode()

    def _draw_frame(self, screen, font, controller, test_car, best_car,
                    test_endpoints, best_endpoints, steps, max_steps,
                    steering, acceleration, eval_start_time, debug):
        """Draw the track, the cars, the HUD text and the network view."""
        screen.fill((0, 0, 0))

        # Draw track, cars and info on the left side (800x600 portion)
        self.track.draw(screen)

        # Best car is drawn first so the test car ends up on top
        if best_car:
            best_car.draw(screen, best_endpoints, show_sensors=False)
        test_car.draw(screen, test_endpoints, show_sensors=not FAST_MODE)

        elapsed = max(time.perf_counter() - eval_start_time, 1e-6)
        steps_per_second = (steps + 1) / elapsed

        # The difficulty attribute is attached after construction by
        # create_track, so tolerate tracks that do not carry it.
        difficulty = getattr(self.track, 'difficulty', None)
        difficulty_label = difficulty.upper() if difficulty else "UNKNOWN"

        texts = [
            f"Test Car (Red) - Velocity: {test_car.velocity:.2f}",
            f"Steering: {steering:.2f}",
            f"Acceleration: {acceleration:.2f}",
            f"Turn ability: {test_car.steering_factor:.2f}",
            f"Laps: {test_car.laps}",
            f"Distance: {test_car.distance_traveled:.0f}",
            f"Step: {steps}/{max_steps}",
            f"Steps/sec: {steps_per_second:.0f}",
            f"Fitness: {self._calculate_fitness(test_car):.2f}",
            f"Difficulty: {difficulty_label}"
        ]

        if best_car:
            texts.extend([
                "",  # Empty line for spacing
                f"Best Car (Green) - Velocity: {best_car.velocity:.2f}",
                f"Laps: {best_car.laps}",
                f"Distance: {best_car.distance_traveled:.0f}",
                f"Fitness: {self._calculate_fitness(best_car):.2f}",
            ])

        if debug:
            texts.extend(test_car.debug_info)

        difficulty_color = self.DIFFICULTY_COLORS.get(difficulty, self.WHITE)

        info_y = 10
        for text in texts:
            # Headline and difficulty rows get their own colour
            if text.startswith("Difficulty: "):
                color = difficulty_color
            elif text.startswith("Test Car"):
                color = self.TEST_CAR_COLOR
            elif text.startswith("Best Car"):
                color = self.BEST_CAR_COLOR
            else:
                color = self.WHITE
            screen.blit(font.render(text, True, color), (10, info_y))
            info_y += 25

        # Show fast mode indicator (nothing to draw in normal mode)
        if FAST_MODE:
            speed_indicator = font.render("FAST MODE", True, (255, 255, 0))
            screen.blit(speed_indicator, (600 - speed_indicator.get_width(), 10))

        # Draw neural network visualization on the right side (if applicable)
        if isinstance(controller, NeuralNetworkController):
            # Draw a border for the NN visualization area
            pygame.draw.rect(screen, (50, 50, 80), (810, 10, 380, 580), 1)
            controller.visualize_network(screen, 820, 20, 360, 560)

    def evaluate(self, controller, max_steps=1000, render=False, debug=False, best_controller=None):
        """Simulate ``controller`` on the current track and score it.

        A red test car is driven by ``controller``. When ``best_controller``
        is given, a green car driven by it runs alongside for comparison. The
        simulation stops after ``max_steps`` steps or once no car is alive.

        Args:
            controller: controller under test; must provide ``predict``.
            max_steps: maximum number of simulation steps.
            render: draw the simulation to the current pygame display and
                process window events while it runs.
            debug: with ``render``, print crash/lap messages and show the
                test car's debug info on screen.
            best_controller: optional reference controller for the green car.

        Returns:
            The test car's fitness when ``render`` is False. When ``render``
            is True, the tuple ``(fitness, best_car_fitness, reached_limit)``;
            ``best_car_fitness`` is None without a ``best_controller``.

        Raises:
            SystemExit: if the window is closed or ESC is pressed while a
                rendered simulation is running.
        """
        # Create the test car (red)
        test_car = Car(self.start_pos[0], self.start_pos[1], self.start_angle, color=(255, 0, 0))

        # Create the best car (green) if a best controller is provided
        best_car = None
        if best_controller:
            best_car = Car(self.start_pos[0], self.start_pos[1], self.start_angle, color=(0, 255, 0))

        # Latest control outputs of the test car, kept for the HUD
        current_acceleration = 0
        current_steering = 0

        if render:
            screen = pygame.display.get_surface()
            clock = pygame.time.Clock()
            eval_start_time = time.perf_counter()
            font = pygame.font.Font(None, 24)

        verbose = debug and render
        steps = 0

        # Continue until both cars have crashed or max steps is reached
        while steps < max_steps and (test_car.alive or (best_car and best_car.alive)):
            # Sensors are read even for a crashed test car so its rays can
            # still be drawn.
            test_readings, test_sensor_endpoints = test_car.get_sensor_readings(self.track)

            if test_car.alive:
                current_acceleration, current_steering = self._advance_car(
                    test_car, controller, test_readings, "Test", steps, verbose)

            # Process best car if it exists
            best_sensor_endpoints = None
            if best_car and best_car.alive and best_controller:
                best_readings, best_sensor_endpoints = best_car.get_sensor_readings(self.track)
                self._advance_car(
                    best_car, best_controller, best_readings, "Best", steps, verbose)

            # Rendering the network every step is the main fast-mode bottleneck.
            if render and (not FAST_MODE or steps % FAST_MODE_RENDER_EVERY == 0):
                self._process_simulation_events()
                self._draw_frame(
                    screen, font, controller, test_car, best_car,
                    test_sensor_endpoints, best_sensor_endpoints,
                    steps, max_steps, current_steering, current_acceleration,
                    eval_start_time, debug)
                pygame.display.flip()

                # Cap framerate based on fast mode
                if FAST_MODE:
                    clock.tick()  # No FPS limit in fast mode
                else:
                    clock.tick(60)  # Normal 60 FPS limit

            steps += 1

        # Calculate fitness of the test car
        fitness = self._calculate_fitness(test_car)
        best_car_fitness = self._calculate_fitness(best_car) if best_car else None

        if render and debug:
            if not test_car.alive:
                print(f"Test car crashed at step {steps}")
            if best_car and not best_car.alive:
                print(f"Best car crashed at step {steps}")
            elif steps >= max_steps:
                print(f"Cars reached step limit ({max_steps})")

        # True when the run was cut off by the step limit with a car still alive
        reached_limit = bool(
            steps >= max_steps and (test_car.alive or (best_car and best_car.alive))
        )

        if render:
            return (fitness, best_car_fitness, reached_limit)
        return fitness

    def _calculate_fitness(self, car):
        """Return the car's score: non-negative angular progress times 100."""
        # Net angular progress around track center.
        # Cumulative: backward rotation subtracts, so spinning nets ~0.
        # One full lap ≈ 2π ≈ 6.28 → ~628 points.
        return max(0, car.angle_progress) * 100

    def optimize(self, iterations=50, mutation_rate=0.3, mutation_amount=0.2):
        """Run hill climbing and return the best controller found.

        Each iteration mutates the current best controller and keeps the
        mutant only if it scores strictly higher.

        Args:
            iterations: number of mutants to try.
            mutation_rate: passed to the controller's ``mutate``.
            mutation_amount: passed to the controller's ``mutate``; shrinks
                by 1% after every accepted improvement.
        """
        for i in range(iterations):
            mutated_controller = self.best_controller.mutate(mutation_rate, mutation_amount)
            fitness = self.evaluate(mutated_controller)

            if fitness > self.best_fitness:
                self.best_controller = mutated_controller
                self.best_fitness = fitness
                print(f"Iteration {i+1}: New best fitness: {self.best_fitness}")

                # Gradually reduce mutation as we improve
                mutation_amount *= 0.99

        return self.best_controller

# Calculate dynamic starting position based on track
def calculate_starting_position(track):
    """Derive a spawn point and heading from the track's start/finish line.

    The spawn point sits 30 pixels away from the midpoint of
    ``track.start_finish`` along the line's perpendicular, and the heading
    points along that same perpendicular.

    Returns:
        ``((x, y), angle)`` with the angle in radians.
    """
    line_start = track.start_finish[0]
    line_end = track.start_finish[1]

    # Midpoint of the start/finish line
    mid_x = (line_start[0] + line_end[0]) / 2
    mid_y = (line_start[1] + line_end[1]) / 2

    # Unit vector along the line
    dx = line_end[0] - line_start[0]
    dy = line_end[1] - line_start[1]
    line_length = math.sqrt(dx**2 + dy**2)
    unit_x = dx / line_length
    unit_y = dy / line_length

    # Perpendicular obtained by rotating the unit vector a quarter turn:
    # (x, y) -> (y, -x). This is the direction the car is meant to race in.
    normal_x, normal_y = unit_y, -unit_x

    # Step off the line far enough that the car does not start on top of it
    offset = 30
    spawn = (mid_x + normal_x * offset, mid_y + normal_y * offset)
    heading = math.atan2(normal_y, normal_x)

    return spawn, heading

# Create a more complex track with challenging features based on difficulty
def create_track(difficulty="random"):
    """Generate a randomised closed circuit and return it as a ``Track``.

    The outer and inner boundaries are radial curves around the centre of an
    800x600 area. Each radius is a base value plus several sine-wave
    "features" whose amplitudes and frequencies depend on the difficulty.

    Args:
        difficulty: "easy", "medium", "hard" or "random" (one of the three is
            picked). Any other value uses the hard base parameters but keeps
            60 boundary points and none of the medium/hard-only extras.

    Returns:
        A ``Track`` with ``difficulty`` and ``special_features`` attributes
        attached.
    """
    area_w, area_h = 800, 600
    cx, cy = area_w // 2, area_h // 2

    # Resolve "random" to a concrete difficulty
    if difficulty == "random":
        difficulty = random.choice(["easy", "medium", "hard"])

    is_medium = difficulty == "medium"
    is_hard = difficulty == "hard"

    # Hard tracks use more boundary points for higher resolution
    n_points = 80 if is_hard else 60

    full_turn = 2 * math.pi
    any_phase = (0, full_turn)

    # Per-difficulty presets: (minimum track width, right-side damping factor,
    # sampling ranges). The ranges are listed in the exact order they are
    # drawn: outer base radius, inner base radius, sharp-turn amp/freq/phase,
    # chicane amp/freq/phase, hairpin amp/freq1/freq2/phase, narrow-section
    # amp/freq1/freq2.
    if difficulty == "easy":
        # Wider track, gentle and regular features
        min_track_width, right_side_factor, ranges = 80, 0.6, [
            (250, 280), (170, 200),
            (15, 30), (1.5, 2.0), any_phase,
            (10, 20), (3, 4), any_phase,
            (10, 20), (1.2, 1.5), (0.8, 1.0), any_phase,
            (10, 15), (0.8, 1.0), (0.5, 0.8),
        ]
    elif is_medium:
        # Moderate width, some challenges
        min_track_width, right_side_factor, ranges = 70, 0.4, [
            (240, 270), (150, 180),
            (25, 45), (2.0, 2.5), any_phase,
            (15, 25), (4, 6), any_phase,
            (20, 30), (1.6, 1.9), (1.0, 1.5), any_phase,
            (15, 25), (0.9, 1.2), (0.6, 0.9),
        ]
    else:  # hard
        # Narrower in places, complex and less regular features
        min_track_width, right_side_factor, ranges = 50, 0.3, [
            (230, 260), (140, 170),
            (40, 60), (2.5, 3.2), any_phase,
            (25, 40), (6, 8), any_phase,
            (35, 50), (1.9, 2.2), (1.4, 1.7), any_phase,
            (30, 45), (1.2, 1.6), (0.8, 1.2),
        ]

    (outer_base, inner_base,
     sharp_amp, sharp_freq, sharp_phase,
     chicane_amp, chicane_freq, chicane_phase,
     hairpin_amp, hairpin_freq_a, hairpin_freq_b, hairpin_phase,
     narrow_amp, narrow_freq_a, narrow_freq_b) = [
        random.uniform(low, high) for low, high in ranges
    ]

    # S-curves exist only on medium and hard tracks
    s_amp = s_freq = s_phase = 0
    if is_medium:
        s_amp = random.uniform(15, 25)
        s_freq = random.uniform(4, 5)
        s_phase = random.uniform(0, full_turn)
    elif is_hard:
        s_amp = random.uniform(30, 40)
        s_freq = random.uniform(5, 7)
        s_phase = random.uniform(0, full_turn)

    # Decreasing-radius turns exist only on hard tracks
    shrink_strength = shrink_freq = shrink_phase = 0
    if is_hard:
        shrink_strength = random.uniform(0.1, 0.2)
        shrink_freq = random.uniform(2, 3)
        shrink_phase = random.uniform(0, full_turn)

    special_features = []

    # Figure-8 style pinch: 50% chance, hard only
    pinch_enabled = False
    pinch_angle = 0
    if is_hard and random.random() < 0.5:
        pinch_enabled = True
        pinch_angle = random.uniform(0, full_turn)

    # Split-path marker: 30% chance on medium and hard
    split_enabled = False
    split_start = 0
    split_end = 0
    if (is_medium or is_hard) and random.random() < 0.3:
        split_enabled = True
        split_start = random.uniform(0, math.pi)
        split_end = split_start + random.uniform(math.pi/4, math.pi/2)

    # Width variation is stronger on hard tracks and absent on easy ones
    width_amp = 0
    if is_medium:
        width_amp = random.uniform(5, 15)
    elif is_hard:
        width_amp = random.uniform(15, 30)

    outer_points = []
    inner_points = []

    for idx in range(n_points):
        theta = full_turn * idx / n_points
        cos_t = math.cos(theta)
        sin_t = math.sin(theta)

        # Damp all features on the right-hand side, where the car starts
        damping = right_side_factor + (1 - right_side_factor) * (1 - cos_t) / 2

        sharp = sharp_amp * damping * math.sin(theta * sharp_freq + sharp_phase)
        chicane = chicane_amp * damping * math.sin(theta * chicane_freq + chicane_phase)

        s_curve = 0
        if is_medium or is_hard:
            s_curve = s_amp * damping * math.sin(theta * s_freq + s_phase) * math.cos(theta * (s_freq/2))

        hairpin = hairpin_amp * damping * math.sin(theta * hairpin_freq_a - hairpin_phase) * math.sin(theta * hairpin_freq_b)
        narrow = narrow_amp * damping * math.sin(theta * narrow_freq_a) * math.cos(theta * narrow_freq_b)

        # Squared sine makes the tightening non-negative and more pronounced
        shrink = 0
        if is_hard:
            wave = shrink_strength * math.sin(theta * shrink_freq + shrink_phase)
            shrink = wave * wave * 100

        # Pull both boundaries inward within 0.3 rad of the pinch angle
        pinch = 0
        if pinch_enabled:
            gap = abs(theta - pinch_angle) % full_turn
            if gap < 0.3 or gap > (full_turn - 0.3):
                pinch = -30

        r_outer = outer_base + sharp + chicane + s_curve + hairpin - abs(narrow) * 0.5 + pinch
        # The inner boundary follows the same features, attenuated
        r_inner = inner_base + sharp * 0.7 + chicane * 0.6 + s_curve * 0.7 + hairpin * 0.7 + narrow * 0.3 + pinch

        if shrink > 0:
            r_outer -= shrink
            r_inner -= shrink * 0.7

        # Width variation only moves the inner boundary
        if width_amp > 0:
            r_inner += width_amp * math.sin(theta * 3.7 + 1.2) * math.sin(theta * 2.3)

        # Keep the inner boundary at least min_track_width inside the outer one
        r_inner = min(r_inner, r_outer - min_track_width)

        outer_points.append((cx + r_outer * cos_t, cy + r_outer * sin_t))
        inner_points.append((cx + r_inner * cos_t, cy + r_inner * sin_t))

    # The split path is recorded as a marker only; the geometry is unchanged
    if split_enabled:
        special_features.append(("split_path", split_start, split_end))

    # One smoothing pass per boundary to soften extremely tight corners
    outer_points = smooth_track(outer_points)
    inner_points = smooth_track(inner_points)

    track = Track(outer_points, inner_points)
    track.difficulty = difficulty  # Shown in the HUD
    track.special_features = special_features

    return track

def smooth_track(points):
    """Return a smoothed copy of a closed loop of (x, y) points.

    Every point is replaced by the average of itself and its two neighbours;
    the first and last points are neighbours because the loop wraps around.
    """
    count = len(points)

    def averaged(i):
        # Three-point moving average with wrap-around indexing
        before = points[(i - 1) % count]
        here = points[i]
        after = points[(i + 1) % count]
        return (
            (before[0] + here[0] + after[0]) / 3,
            (before[1] + here[1] + after[1]) / 3,
        )

    return [averaged(i) for i in range(count)]

def wait_for_key():
    """Block until a key or mouse button is pressed.

    Polls the pygame event queue about every 100 ms and prints a reminder
    whenever more than half a second has passed since the last one. Closing
    the window shuts pygame down and exits the program.
    """
    print("Waiting for key press...")
    last_notice = time.time()

    while True:
        now = time.time()
        if now - last_notice > 0.5:
            print("Still waiting for key press...")
            last_notice = now

        for event in pygame.event.get():
            kind = event.type
            if kind == pygame.QUIT:
                print("Quit event received during wait_for_key")
                pygame.quit()
                sys.exit()
            elif kind == pygame.KEYDOWN:
                print(f"Key pressed: {pygame.key.name(event.key)}")
                return
            elif kind == pygame.MOUSEBUTTONDOWN:
                print(f"Mouse button pressed: {event.button}")
                return

        # Sleep briefly so the loop does not hog the CPU
        pygame.time.wait(100)

def handle_events():
    """Drain pending pygame events; return True if the program should quit.

    A window-close event or the ESC key requests quitting. The F key toggles
    fast mode. All other events are ignored.
    """
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            return True
        if event.type != pygame.KEYDOWN:
            continue
        if event.key == pygame.K_ESCAPE:
            return True
        if event.key == pygame.K_f:
            toggle_fast_mode()

    return False

# Entry point: opens the 1200x800 window and runs the continuous training loop
def main():
    try:
        print("Starting self-driving car simulator with Neural Network controller...")
        # Initialize pygame with larger window
        pygame.init()
        width, height = 1200, 800  # Increased from 800x600 to 1200x800
        screen = pygame.display.set_mode((width, height))
        pygame.display.set_caption("Self-Driving Car with Neural Network")
        clock = pygame.time.Clock()
        
        # Initialize fonts
        font = pygame.font.Font(None, 36)
        small_font = pygame.font.Font(None, 20)
        
        # CONTINUOUS TRAINING WITH TRACK REGENERATION
        iteration_counter = 0
        track_counter = 0
        best_fitness_ever = 0
        current_step_limit = 300
        
        # Track all-time best controller
        all_time_best_controller = None
        
        # Create the first track
        track = create_track(TRACK_DIFFICULTY)
        track_counter += 1
        
        # Calculate dynamic starting position
        start_pos, start_angle = calculate_starting_position(track)
        
        # Create hill climber with dynamic starting position and increased input size (7+2)
        # The 7 original sensor inputs + steering angle + velocity
        hill_climber = HillClimber(track, start_pos, start_angle, input_size=7)
        
        # Store the initial controller as the all-time best
        all_time_best_controller = copy.deepcopy(hill_climber.best_controller)
        
        # Number of tracks to average over for each comparison
        EVAL_TRACKS = 3

        # Using an infinite loop to run forever
        while True:
            # Evaluate both controllers on multiple tracks to get stable fitness
            # First track is rendered; the rest run silently in the background
            tracks = []
            for t in range(EVAL_TRACKS):
                tr = create_track(TRACK_DIFFICULTY)
                sp, sa = calculate_starting_position(tr)
                tracks.append((tr, sp, sa))

            # Use the first track for the rendered display
            track, start_pos, start_angle = tracks[0]
            track_counter += 1
            hill_climber.track = track
            hill_climber.start_pos = start_pos
            hill_climber.start_angle = start_angle
            
            # Show an overview of the current training state
            screen.fill((0, 0, 0))
            
            # Draw track in the left 800x600 portion of the screen
            track.draw(screen)
            overlay = pygame.Surface((800, 600), pygame.SRCALPHA)
            overlay.fill((0, 0, 0, 128))
            screen.blit(overlay, (0, 0))
            
            # Visualize the starting position
            pygame.draw.circle(screen, (0, 255, 0), (int(start_pos[0]), int(start_pos[1])), 8)
            # Draw a line showing the starting angle
            line_end_x = start_pos[0] + math.cos(start_angle) * 30
            line_end_y = start_pos[1] + math.sin(start_angle) * 30
            pygame.draw.line(screen, (0, 255, 0), start_pos, (line_end_x, line_end_y), 2)
            
            # Display training info
            info_text = font.render(f"Track #{track_counter}, Iteration #{iteration_counter+1}", 
                                   True, (255, 255, 255))
            fitness_text = font.render(f"All-time Best: {best_fitness_ever:.2f} | Step Limit: {current_step_limit}",
                                      True, (255, 255, 255))
            
            # Color-code difficulty text based on level
            difficulty_color = (255, 255, 255)
            if track.difficulty == "easy":
                difficulty_color = (0, 255, 0)  # Green
            elif track.difficulty == "medium":
                difficulty_color = (255, 165, 0)  # Orange
            elif track.difficulty == "hard":
                difficulty_color = (255, 0, 0)  # Red
                
            difficulty_text = font.render(f"Difficulty: {track.difficulty.upper()}", 
                                         True, difficulty_color)
                                         
            speed_mode = "FAST MODE ENABLED" if FAST_MODE else "Normal Speed"
            speed_text = font.render(f"Speed: {speed_mode}", True, (255, 255, 0) if FAST_MODE else (200, 200, 200))
            
            # Display controller type
            controller_type = "Neural Network" if isinstance(hill_climber.best_controller, NeuralNetworkController) else "Polynomial"
            controller_text = font.render(f"Controller: {controller_type}", True, (200, 200, 255))
            
            if all_time_best_controller:
                all_time_text = font.render(f"All-time Best Fitness: {best_fitness_ever:.2f}", True, (0, 255, 0))
                screen.blit(all_time_text, (400 - all_time_text.get_width()//2, 190))
            
            continue_text = small_font.render("Press ESC to exit, F to toggle fast mode, any other key to jump ahead 20 iterations", 
                                            True, (200, 200, 200))
            
            # Position text in the left part of the screen
            screen.blit(info_text, (400 - info_text.get_width()//2, 30))
            screen.blit(fitness_text, (400 - fitness_text.get_width()//2, 70))
            screen.blit(difficulty_text, (400 - difficulty_text.get_width()//2, 110))
            screen.blit(speed_text, (400 - speed_text.get_width()//2, 150))
            screen.blit(controller_text, (400 - controller_text.get_width()//2, 230))
            screen.blit(continue_text, (400 - continue_text.get_width()//2, 550))
            
            # Draw neural network in the right side of the screen with larger size
            if isinstance(hill_climber.best_controller, NeuralNetworkController):
                # Draw a border for the NN visualization area
                pygame.draw.rect(screen, (50, 50, 80), (810, 10, 380, 580), 1)
                hill_climber.best_controller.visualize_network(screen, 820, 20, 360, 560)
            
            pygame.display.flip()
            
            # Check for events (F key for fast mode)
            should_quit = handle_events()
            if should_quit:
                pygame.quit()
                sys.exit()
            
            # Create a mutated controller
            mutated_controller = hill_climber.best_controller.mutate()
            
            # Evaluate on rendered track (first one)
            result = hill_climber.evaluate(
                mutated_controller,
                max_steps=current_step_limit,
                render=True,
                debug=False,
                best_controller=hill_climber.best_controller
            )
            mutant_rendered_fitness, best_rendered_fitness, reached_limit = result
            if best_rendered_fitness is None:
                best_rendered_fitness = 0

            if reached_limit and current_step_limit < 2000:
                current_step_limit += 100
                print(f"Car reached step limit! Increasing to {current_step_limit} steps")

            # Evaluate on remaining tracks (silent) and average all
            mutant_total = mutant_rendered_fitness
            best_total = best_rendered_fitness
            for tr, sp, sa in tracks[1:]:
                hill_climber.track = tr
                hill_climber.start_pos = sp
                hill_climber.start_angle = sa
                mutant_total += hill_climber.evaluate(mutated_controller, max_steps=current_step_limit)
                best_total += hill_climber.evaluate(hill_climber.best_controller, max_steps=current_step_limit)

            # Restore the rendered track for display
            hill_climber.track, hill_climber.start_pos, hill_climber.start_angle = tracks[0]

            fitness = mutant_total / EVAL_TRACKS
            best_fitness_avg = best_total / EVAL_TRACKS

            # Mutant must beat the current best averaged across all tracks
            improvement = False
            if fitness > best_fitness_avg:
                improvement = True
                hill_climber.best_controller = mutated_controller
                print(f"Iteration {iteration_counter+1}: Mutant beat best (avg {fitness:.2f} > {best_fitness_avg:.2f} over {EVAL_TRACKS} tracks)")
                
                # Neural network controller - just print summary
                print(f"Neural network with {mutated_controller.hidden_size} hidden neurons updated.")
                # Optionally print some network stats
                ih_avg = np.mean(np.abs(mutated_controller.weights_ih))
                ho_avg = np.mean(np.abs(mutated_controller.weights_ho))
                print(f"Avg weight magnitudes - Input→Hidden: {ih_avg:.4f}, Hidden→Output: {ho_avg:.4f}")
                
                # Update the all-time best fitness and controller if necessary
                if fitness > best_fitness_ever:
                    best_fitness_ever = fitness
                    all_time_best_controller = copy.deepcopy(mutated_controller)
                    print(f"New all-time best fitness: {best_fitness_ever}!")
            
            # Show brief result flash
            screen.fill((0, 0, 0))
            track.draw(screen)
            overlay = pygame.Surface((800, 600), pygame.SRCALPHA)
            overlay.fill((0, 0, 0, 128))
            screen.blit(overlay, (0, 0))
            
            if improvement:
                result_text = font.render(f"IMPROVED! Avg fitness: {fitness:.2f} > {best_fitness_avg:.2f}",
                                         True, (0, 255, 0))
            else:
                result_text = font.render(f"No improvement. Avg: {fitness:.2f} vs {best_fitness_avg:.2f}",
                                         True, (255, 100, 100))
            
            screen.blit(result_text, (400 - result_text.get_width()//2, 300))
            
            # Continue to display NN on the right side during result flash
            if isinstance(hill_climber.best_controller, NeuralNetworkController):
                pygame.draw.rect(screen, (50, 50, 80), (810, 10, 380, 580), 1)
                hill_climber.best_controller.visualize_network(screen, 820, 20, 360, 560)
            
            pygame.display.flip()
            
            # Brief pause to see result
            if not FAST_MODE:
                pygame.time.wait(500)  # Just half a second pause
            
            # Check for user input for jumping ahead
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_ESCAPE:
                        # Exit if user presses escape
                        print("User exited training loop")
                        pygame.quit()
                        sys.exit()
                    elif event.key == pygame.K_f:
                        toggle_fast_mode()
                    elif event.key not in [pygame.K_f, pygame.K_ESCAPE]:
                        # Skip ahead 20 iterations if user presses any other key (except F or Escape)
                        iteration_counter += 19  # Will be incremented again below
                        print("Jumping ahead 20 iterations")
            
            # Increment counters
            iteration_counter += 1
        
    except Exception as e:
        print(f"Error occurred: {e}")
        traceback.print_exc()
        pygame.quit()
        sys.exit(1)

# Start the training loop only when this file is executed as a script.
# Without the guard, merely importing the module (e.g. to reuse
# NeuralNetworkController) would run main(), whose exit paths call
# pygame.quit() and sys.exit().
if __name__ == "__main__":
    main()
